import os
import django
from django.db import transaction, connection
from django.utils import timezone
from celery import shared_task
from functools import wraps
from celery_client.app import app as celery_client

# Django must be configured before any model module is imported. So the
# settings module is pinned here (only if the environment does not already
# set one), and django.setup() runs ahead of the model imports below.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'zmm.settings')
django.setup()

from common.models import Tenant
from common.logging import PrinterLogger, get_task_logger
from load_planner.models import Task, TaskStatus


def with_tenant(celery_task):
    """Run ``celery_task`` inside the schema of the tenant named in its payload.

    The wrapper looks up the tenant from ``payload['tenant_id']`` and activates
    its schema on the shared DB connection. It then runs the task inside one
    ``transaction.atomic()`` block and afterwards restores the previously
    active schemas. The restore also happens when the task raises, so a failing
    task cannot leave the connection pointing at another tenant's schema for
    whatever runs next on it.

    Raises:
        KeyError: if ``payload`` has no ``'tenant_id'``.
        Tenant.DoesNotExist: if no tenant matches the id.
    """
    @wraps(celery_task)
    def decor(payload, *args, **kwargs):
        tenant = Tenant.objects.get(id=payload['tenant_id'])

        # Snapshot the current schemas. Copying guards the later restore in
        # case set_schema mutates the underlying sequence in place.
        prev_schemas = tuple(connection.schemas)
        connection.set_schema(tenant.schema_name)
        try:
            with transaction.atomic():
                return celery_task(payload, *args, **kwargs)
        finally:
            # Restore even on failure; the connection outlives this call.
            connection.set_schema(*prev_schemas)
    return decor


def with_task_logger(celery_task):
    """Inject a task-scoped ``logger`` keyword argument into ``celery_task``.

    When ``payload`` carries a truthy ``'task_id'``, the wrapper loads the
    matching ``Task`` and passes ``get_task_logger(task)`` as ``logger=...``.
    This overrides any ``logger`` the caller supplied. Otherwise the call is
    forwarded untouched, so the wrapped function's own default applies.
    ``Task.objects.get`` raises ``Task.DoesNotExist`` for an unknown id.
    """
    @wraps(celery_task)
    def inject_logger(payload, *args, **kwargs):
        task_id = payload.get('task_id')
        if not task_id:
            # No task to attach to: keep the wrapped function's default logger.
            return celery_task(payload, *args, **kwargs)

        kwargs['logger'] = get_task_logger(Task.objects.get(id=task_id))
        return celery_task(payload, *args, **kwargs)
    return inject_logger


def retrieve_and_lock_task(id):
    """Fetch the task with primary key ``id`` and lock its row for update.

    The query uses ``select_for_update()``, so it has to run inside a
    transaction. The row lock is held until that transaction ends.

    Returns:
        The matching ``Task`` instance, or ``None`` when no row matches.
    """
    # NOTE: ``id`` shadows the builtin but is kept for keyword-argument callers.
    lockable_tasks = Task.objects.select_for_update()
    matching = lockable_tasks.filter(id=id)
    return matching.first()


@shared_task(name='internal.do_submit')
@with_tenant
@with_task_logger
def on_task_submitted(payload, logger=PrinterLogger):
    """Dispatch a business task to the compute backend and mark it submitted.

    Args:
        payload: Mapping with ``'tenant_id'`` and ``'task_id'``.
        logger: Injected by ``with_task_logger`` when ``payload`` names a
            task; otherwise ``PrinterLogger``.

    Raises:
        Task.DoesNotExist: if no task with ``payload['task_id']`` exists.
    """
    # Fetch (and row-lock) the business task to compute.
    task = retrieve_and_lock_task(payload['task_id'])
    if task is None:
        raise Task.DoesNotExist(f"Task {payload['task_id']} does not exist")

    # Build the algorithm input and hand it to the compute backend.
    # A separate name keeps the incoming payload distinct from the outbound one.
    compute_payload = {
        'tenant_id': payload['tenant_id'],
        'task_id': task.id,
        'config': task.config,
        'input_dataset': task.input_dataset,
    }

    logger.info(f'提交任务请求体: {compute_payload}')

    # NOTE(review): the message is sent before the surrounding transaction
    # commits. If the save below fails, the compute job is already dispatched.
    celery_client.send_task('zmm.do_compute', args=[compute_payload])

    # Record the actual submission time.
    task.status = TaskStatus.SUBMITTED.value
    task.submitted_at = timezone.now()
    task.save()

    logger.info('任务已提交')


@shared_task(name='zmm.do_start')
@with_tenant
@with_task_logger
def on_task_started(payload, logger=PrinterLogger):
    """Mark the business task named in ``payload`` as running.

    Args:
        payload: Mapping with ``'tenant_id'`` and ``'task_id'``.
        logger: Injected by ``with_task_logger``; otherwise ``PrinterLogger``.

    Raises:
        Task.DoesNotExist: if no task with ``payload['task_id']`` exists.
    """
    task = retrieve_and_lock_task(payload['task_id'])
    if task is None:
        raise Task.DoesNotExist(f"Task {payload['task_id']} does not exist")

    # Move the task to "running".
    # NOTE(review): the status is overwritten unconditionally. If do_start can
    # be processed after do_finish/do_fail, a finished task flips back to
    # RUNNING. Confirm the backend's delivery-ordering guarantees.
    task.status = TaskStatus.RUNNING.value
    task.save()
    logger.info('任务已经开始运行')


@shared_task(name='zmm.do_finish')
@with_tenant
@with_task_logger
def on_task_finished(payload, logger=PrinterLogger):
    """Mark the business task as finished and store its result dataset.

    Args:
        payload: Mapping with ``'tenant_id'``, ``'task_id'`` and
            ``'result_dataset'``.
        logger: Injected by ``with_task_logger``; otherwise ``PrinterLogger``.

    Raises:
        Task.DoesNotExist: if no task with ``payload['task_id']`` exists.
        KeyError: if ``payload`` has no ``'result_dataset'``.
    """
    task = retrieve_and_lock_task(payload['task_id'])
    if task is None:
        raise Task.DoesNotExist(f"Task {payload['task_id']} does not exist")

    # Move the task to "finished" and record its result.
    task.result_dataset = payload['result_dataset']
    task.status = TaskStatus.FINISHED.value
    task.finished_at = timezone.now()
    task.save()
    logger.info('任务已经结束')


@shared_task(name='zmm.do_fail')
@with_tenant
@with_task_logger
def on_task_failed(payload, logger=PrinterLogger):
    """Mark the business task named in ``payload`` as failed.

    Args:
        payload: Mapping with ``'tenant_id'`` and ``'task_id'``.
        logger: Injected by ``with_task_logger``; otherwise ``PrinterLogger``.

    Raises:
        Task.DoesNotExist: if no task with ``payload['task_id']`` exists.
    """
    task = retrieve_and_lock_task(payload['task_id'])
    if task is None:
        raise Task.DoesNotExist(f"Task {payload['task_id']} does not exist")

    # Move the task to the error state; finished_at records when it stopped.
    task.status = TaskStatus.ERROR.value
    task.finished_at = timezone.now()
    task.save()
    logger.info('任务遇到未知错误')
